//
//  Sim3Solver.cpp
//  ORB_SLAM_iOS
//
//  Created by Xin Sun on 14/10/2015.
//  Copyright © 2015 Xin Sun. All rights reserved.
//

#include "Sim3Solver.hpp"
#include <vector>
#include <cmath>
#include <opencv2/opencv.hpp>

#include "KeyFrame.hpp"
#include "ORBmatcher.hpp"

#include "Thirdparty/DBoW2/DUtils/Random.h"

namespace ORB_SLAM
{
    
    
    Sim3Solver::Sim3Solver(KeyFrame *pKF1, KeyFrame *pKF2, const vector<MapPoint *> &vpMatched12):
    mnIterations(0), mnBestInliers(0)
    {
        mpKF1 = pKF1;
        mpKF2 = pKF2;
        
        vector<MapPoint*> vpKeyFrameMP1 = pKF1->GetMapPointMatches();
        
        mN1 = vpMatched12.size();
        
        mvpMapPoints1.reserve(mN1);
        mvpMapPoints2.reserve(mN1);
        mvpMatches12 = vpMatched12;
        mvnIndices1.reserve(mN1);
        mvX3Dc1.reserve(mN1);
        mvX3Dc2.reserve(mN1);
        
        cv::Mat Rcw1 = pKF1->GetRotation();
        cv::Mat tcw1 = pKF1->GetTranslation();
        cv::Mat Rcw2 = pKF2->GetRotation();
        cv::Mat tcw2 = pKF2->GetTranslation();
        
        mvAllIndices.reserve(mN1);
        
        size_t idx=0;
        for(int i1=0; i1<mN1; i1++)
        {
            if(vpMatched12[i1])
            {
                MapPoint* pMP1 = vpKeyFrameMP1[i1];
                MapPoint* pMP2 = vpMatched12[i1];
                
                if(!pMP1)
                    continue;
                
                if(pMP1->isBad() || pMP2->isBad())
                    continue;
                
                int indexKF1 = pMP1->GetIndexInKeyFrame(pKF1);
                int indexKF2 = pMP2->GetIndexInKeyFrame(pKF2);
                
                if(indexKF1<0 || indexKF2<0)
                    continue;
                
                const cv::KeyPoint &kp1 = pKF1->GetKeyPointUn(indexKF1);
                const cv::KeyPoint &kp2 = pKF2->GetKeyPointUn(indexKF2);
                
                const float sigmaSquare1 = pKF1->GetSigma2(kp1.octave);
                const float sigmaSquare2 = pKF2->GetSigma2(kp2.octave);
                
                mvnMaxError1.push_back(9.210*sigmaSquare1);
                mvnMaxError2.push_back(9.210*sigmaSquare2);
                
                mvpMapPoints1.push_back(pMP1);
                mvpMapPoints2.push_back(pMP2);
                mvnIndices1.push_back(i1);
                
                cv::Mat X3D1w = pMP1->GetWorldPos();
                mvX3Dc1.push_back(Rcw1*X3D1w+tcw1);
                
                cv::Mat X3D2w = pMP2->GetWorldPos();
                mvX3Dc2.push_back(Rcw2*X3D2w+tcw2);
                
                mvAllIndices.push_back(idx);
                idx++;
            }
        }
        
        mK1 = pKF1->GetCalibrationMatrix();
        mK2 = pKF2->GetCalibrationMatrix();
        
        FromCameraToImage(mvX3Dc1,mvP1im1,mK1);
        FromCameraToImage(mvX3Dc2,mvP2im2,mK2);
        
        SetRansacParameters();
    }
    
    void Sim3Solver::SetRansacParameters(double probability, int minInliers, int maxIterations)
    {
        mRansacProb = probability;
        mRansacMinInliers = minInliers;
        mRansacMaxIts = maxIterations;
        
        N = mvpMapPoints1.size(); // number of correspondences
        
        mvbInliersi.resize(N);
        
        // Adjust Parameters according to number of correspondences
        float epsilon = (float)mRansacMinInliers/N;
        
        // Set RANSAC iterations according to probability, epsilon, and max iterations
        int nIterations;
        
        if(mRansacMinInliers==N)
            nIterations=1;
        else
            nIterations = ceil(log(1-mRansacProb)/log(1-pow(epsilon,3)));
        
        mRansacMaxIts = max(1,min(nIterations,mRansacMaxIts));
        
        mnIterations = 0;
    }
    
    cv::Mat Sim3Solver::iterate(int nIterations, bool &bNoMore, vector<bool> &vbInliers, int &nInliers)
    {
        bNoMore = false;
        vbInliers = vector<bool>(mN1,false);
        nInliers=0;
        
        if(N<mRansacMinInliers)
        {
            bNoMore = true;
            return cv::Mat();
        }
        
        vector<size_t> vAvailableIndices;
        
        cv::Mat P3Dc1i(3,3,CV_32F);
        cv::Mat P3Dc2i(3,3,CV_32F);
        
        int nCurrentIterations = 0;
        while(mnIterations<mRansacMaxIts && nCurrentIterations<nIterations)
        {
            nCurrentIterations++;
            mnIterations++;
            
            vAvailableIndices = mvAllIndices;
            
            // Get min set of points
            for(short i = 0; i < 3; ++i)
            {
                int randi = DUtils::Random::RandomInt(0, vAvailableIndices.size()-1);
                
                int idx = vAvailableIndices[randi];
                
                mvX3Dc1[idx].copyTo(P3Dc1i.col(i));
                mvX3Dc2[idx].copyTo(P3Dc2i.col(i));
                
                vAvailableIndices[idx] = vAvailableIndices.back();
                vAvailableIndices.pop_back();
            }
            
            computeT(P3Dc1i,P3Dc2i);
            
            CheckInliers();
            
            if(mnInliersi>=mnBestInliers)
            {
                mvbBestInliers = mvbInliersi;
                mnBestInliers = mnInliersi;
                mBestT12 = mT12i.clone();
                mBestRotation = mR12i.clone();
                mBestTranslation = mt12i.clone();
                mBestScale = ms12i;
                
                if(mnInliersi>mRansacMinInliers)
                {
                    nInliers = mnInliersi;
                    for(int i=0; i<N; i++)
                        if(mvbInliersi[i])
                            vbInliers[mvnIndices1[i]] = true;
                    return mBestT12;
                }
            }
        }
        
        if(mnIterations>=mRansacMaxIts)
            bNoMore=true;
        
        return cv::Mat();
    }
    
    cv::Mat Sim3Solver::find(vector<bool> &vbInliers12, int &nInliers)
    {
        bool bFlag;
        return iterate(mRansacMaxIts,bFlag,vbInliers12,nInliers);
    }
    
    void Sim3Solver::centroid(cv::Mat &P, cv::Mat &Pr, cv::Mat &C)
    {
        cv::reduce(P,C,1,CV_REDUCE_SUM);
        C = C/P.cols;
        
        for(int i=0; i<P.cols; i++)
        {
            Pr.col(i)=P.col(i)-C;
        }
    }
    
    void Sim3Solver::computeT(cv::Mat &P1, cv::Mat &P2)
    {
        // Custom implementation of:
        // Horn 1987, Closed-form solution of absolute orientataion using unit quaternions
        
        // Step 1: Centroid and relative coordinates
        
        cv::Mat Pr1(P1.size(),P1.type()); // Relative coordinates to centroid (set 1)
        cv::Mat Pr2(P2.size(),P2.type()); // Relative coordinates to centroid (set 2)
        cv::Mat O1(3,1,Pr1.type()); // Centroid of P1
        cv::Mat O2(3,1,Pr2.type()); // Centroid of P2
        
        centroid(P1,Pr1,O1);
        centroid(P2,Pr2,O2);
        
        // Step 2: Compute M matrix
        
        cv::Mat M = Pr2*Pr1.t();
        
        // Step 3: Compute N matrix
        
        double N11, N12, N13, N14, N22, N23, N24, N33, N34, N44;
        
        cv::Mat N(4,4,P1.type());
        
        N11 = M.at<float>(0,0)+M.at<float>(1,1)+M.at<float>(2,2);
        N12 = M.at<float>(1,2)-M.at<float>(2,1);
        N13 = M.at<float>(2,0)-M.at<float>(0,2);
        N14 = M.at<float>(0,1)-M.at<float>(1,0);
        N22 = M.at<float>(0,0)-M.at<float>(1,1)-M.at<float>(2,2);
        N23 = M.at<float>(0,1)+M.at<float>(1,0);
        N24 = M.at<float>(2,0)+M.at<float>(0,2);
        N33 = -M.at<float>(0,0)+M.at<float>(1,1)-M.at<float>(2,2);
        N34 = M.at<float>(1,2)+M.at<float>(2,1);
        N44 = -M.at<float>(0,0)-M.at<float>(1,1)+M.at<float>(2,2);
        
        N = (cv::Mat_<float>(4,4) << N11, N12, N13, N14,
             N12, N22, N23, N24,
             N13, N23, N33, N34,
             N14, N24, N34, N44);
        
        
        // Step 4: Eigenvector of the highest eigenvalue
        
        cv::Mat eval, evec;
        
        cv::eigen(N,eval,evec); //evec[0] is the quaternion of the desired rotation
        
        cv::Mat vec(1,3,evec.type());
        (evec.row(0).colRange(1,4)).copyTo(vec); //extract imaginary part of the quaternion (sin*axis)
        
        // Rotation angle. sin is the norm of the imaginary part, cos is the real part
        double ang=atan2(norm(vec),evec.at<float>(0,0));
        
        vec = 2*ang*vec/norm(vec); //Angle-axis representation. quaternion angle is the half
        
        mR12i.create(3,3,P1.type());
        
        cv::Rodrigues(vec,mR12i); // computes the rotation matrix from angle-axis
        
        // Step 5: Rotate set 2
        
        cv::Mat P3 = mR12i*Pr2;
        
        // Step 6: Scale
        
        double nom = Pr1.dot(P3);
        cv::Mat aux_P3(P3.size(),P3.type());
        aux_P3=P3;
        cv::pow(P3,2,aux_P3);
        double den = 0;
        
        for(int i=0; i<aux_P3.rows; i++)
        {
            for(int j=0; j<aux_P3.cols; j++)
            {
                den+=aux_P3.at<float>(i,j);
            }
        }
        
        ms12i = nom/den;
        
        // Step 7: Translation
        
        mt12i.create(1,3,P1.type());
        mt12i = O1 - ms12i*mR12i*O2;
        
        // Step 8: Transformation
        
        // Step 8.1 T12
        mT12i = cv::Mat::eye(4,4,P1.type());
        
        cv::Mat sR = ms12i*mR12i;
        
        sR.copyTo(mT12i.rowRange(0,3).colRange(0,3));
        mt12i.copyTo(mT12i.rowRange(0,3).col(3));
        
        // Step 8.2 T21
        
        mT21i = cv::Mat::eye(4,4,P1.type());
        
        cv::Mat sRinv = (1.0/ms12i)*mR12i.t();
        
        sRinv.copyTo(mT21i.rowRange(0,3).colRange(0,3));
        cv::Mat tinv = -sRinv*mt12i;
        tinv.copyTo(mT21i.rowRange(0,3).col(3));
    }
    
    
    void Sim3Solver::CheckInliers()
    {
        vector<cv::Mat> vP1im2, vP2im1;
        Project(mvX3Dc2,vP2im1,mT12i,mK1);
        Project(mvX3Dc1,vP1im2,mT21i,mK2);
        
        mnInliersi=0;
        
        for(size_t i=0; i<mvP1im1.size(); i++)
        {
            cv::Mat dist1 = mvP1im1[i]-vP2im1[i];
            cv::Mat dist2 = vP1im2[i]-mvP2im2[i];
            
            float err1 = dist1.dot(dist1);
            float err2 = dist2.dot(dist2);
            
            if(err1<mvnMaxError1[i] && err2<mvnMaxError2[i])
            {
                mvbInliersi[i]=true;
                mnInliersi++;
            }
            else
                mvbInliersi[i]=false;
        }
    }
    
    
    cv::Mat Sim3Solver::GetEstimatedRotation()
    {
        return mBestRotation.clone();
    }
    
    cv::Mat Sim3Solver::GetEstimatedTranslation()
    {
        return mBestTranslation.clone();
    }
    
    float Sim3Solver::GetEstimatedScale()
    {
        return mBestScale;
    }
    
    void Sim3Solver::Project(const vector<cv::Mat> &vP3Dw, vector<cv::Mat> &vP2D, cv::Mat Tcw, cv::Mat K)
    {
        cv::Mat Rcw = Tcw.rowRange(0,3).colRange(0,3);
        cv::Mat tcw = Tcw.rowRange(0,3).col(3);
        float fx = K.at<float>(0,0);
        float fy = K.at<float>(1,1);
        float cx = K.at<float>(0,2);
        float cy = K.at<float>(1,2);
        
        vP2D.clear();
        vP2D.reserve(vP3Dw.size());
        
        for(size_t i=0, iend=vP3Dw.size(); i<iend; i++)
        {
            cv::Mat P3Dc = Rcw*vP3Dw[i]+tcw;
            float invz = 1/(P3Dc.at<float>(2));
            float x = P3Dc.at<float>(0)*invz;
            float y = P3Dc.at<float>(1)*invz;
            
            vP2D.push_back((cv::Mat_<float>(2,1) << fx*x+cx, fy*y+cy));
        }
    }
    
    void Sim3Solver::FromCameraToImage(const vector<cv::Mat> &vP3Dc, vector<cv::Mat> &vP2D, cv::Mat K)
    {
        float fx = K.at<float>(0,0);
        float fy = K.at<float>(1,1);
        float cx = K.at<float>(0,2);
        float cy = K.at<float>(1,2);
        
        vP2D.clear();
        vP2D.reserve(vP3Dc.size());
        
        for(size_t i=0, iend=vP3Dc.size(); i<iend; i++)
        {
            float invz = 1/(vP3Dc[i].at<float>(2));
            float x = vP3Dc[i].at<float>(0)*invz;
            float y = vP3Dc[i].at<float>(1)*invz;
            
            vP2D.push_back((cv::Mat_<float>(2,1) << fx*x+cx, fy*y+cy));
        }
    }
    
} //namespace ORB_SLAM